import torch
import numpy as np
from sklearn import neighbors


class FPGReward:
    """Scalar reward built from the ratio of expert to agent state densities.

    Both density estimators are supplied through :meth:`update` and must
    expose ``score_samples(state)``.  The values returned by that call are
    exponentiated here, i.e. they are treated as log-densities (presumably
    an sklearn ``KernelDensity``-style estimator -- confirm against caller).

    ``obj`` selects how the clipped ratio ``r = expert / agent`` is turned
    into a reward:

    * ``'fkl'``  -> ``log(r)``
    * ``'rkl'``  -> ``r``
    * ``'js'``   -> ``log((1 + r) / 2)``
    * ``'chi2'`` -> ``-1 / r``
    """

    # Lower bounds applied to the exponentiated densities before taking the
    # ratio, so neither the numerator nor the denominator can reach zero.
    EXPERT_DENSITY_FLOOR = 3e-2
    AGENT_DENSITY_FLOOR = 1e-2

    # Objective names understood by ``get_scalar_reward``.
    SUPPORTED_OBJECTIVES = ('fkl', 'rkl', 'js', 'chi2')

    def __init__(self, obj):
        """Store the objective name; densities are set later via ``update``.

        Args:
            obj: One of ``SUPPORTED_OBJECTIVES``.  It is not validated here,
                only when ``get_scalar_reward`` is called.
        """
        self.agent_density = None
        self.expert_density = None
        self.obj = obj

    def get_scalar_reward(self, state):
        """Return the reward for ``state`` under the configured objective.

        Args:
            state: Samples forwarded unchanged to ``score_samples`` of both
                density estimators.

        Returns:
            The element-wise reward computed with NumPy from the two
            ``score_samples`` outputs.

        Raises:
            ValueError: If ``self.obj`` is not a supported objective.
                Previously this case silently returned ``None``.
        """
        expert_density = np.clip(
            np.exp(self.expert_density.score_samples(state)),
            a_min=self.EXPERT_DENSITY_FLOOR,
            a_max=None,
        )
        agent_density = np.clip(
            np.exp(self.agent_density.score_samples(state)),
            a_min=self.AGENT_DENSITY_FLOOR,
            a_max=None,
        )
        ratio = expert_density / agent_density

        if self.obj == 'fkl':
            return np.log(ratio)
        if self.obj == 'rkl':
            return ratio
        if self.obj == 'js':
            # Algebraically equal to log((1 + 1/r) / (2 * (1/r))).
            return np.log((1.0 + ratio) / 2.0)
        if self.obj == 'chi2':
            return -1.0 / ratio

        raise ValueError(
            f"Unknown objective {self.obj!r}; expected one of "
            f"{self.SUPPORTED_OBJECTIVES}"
        )

    def update(self, agent_density, expert_density):
        """Replace both density estimators used by ``get_scalar_reward``.

        Args:
            agent_density: Object exposing ``score_samples(state)``.
            expert_density: Object exposing ``score_samples(state)``.
        """
        self.agent_density = agent_density
        self.expert_density = expert_density